import torch

# Directory that holds the input checkpoints and receives the merged file.
# Relative, so it resolves against the current working directory.
path = "../checkpoint/"

# Load the FeatNet checkpoint and pull out its five per-region entries.
# NOTE(review): torch.load is called without map_location, so tensors are
# restored to whatever device they were saved from -- confirm this is the
# intended behaviour when running on a machine without that device.
featnet_state = torch.load(f"{path}FeatNet_2000epoch.pth")
(
    featnet_leye,
    featnet_reye,
    featnet_mouth,
    featnet_nose,
    featnet_jaw,
) = (
    featnet_state[key]
    for key in (
        'featnet_leye',
        'featnet_reye',
        'featnet_mouth',
        'featnet_nose',
        'featnet_jaw',
    )
)

# Landmark indices per facial region (presumably the 68-point annotation
# scheme, indices 0..67 -- TODO confirm).  For every region except the jaw
# one index is pulled to the front of the list; the remaining indices follow
# in ascending order.
leye_idx = [39, *range(17, 22), *range(36, 39), *range(40, 42)]
reye_idx = [42, *range(22, 27), *range(43, 48)]
mouth_idx = [51, *range(48, 51), *range(52, 68)]
nose_idx = [30, *range(27, 30), *range(31, 36)]
jaw_idx = list(range(17))


def _anchor_first(z_mean, anchor):
    """Return ``z_mean`` with row ``anchor`` moved to position 0.

    The selected row is reshaped to ``(-1, 128)`` and concatenated along
    dim 0 with the rows before it and the rows after it, so the relative
    order of all other rows is unchanged.
    """
    front = z_mean[anchor].view(-1, 128)
    return torch.cat((front, z_mean[:anchor], z_mean[anchor + 1:]), dim=0)


# Per-region FeatNet priors.  The row moved to the front (8, 5, 3, 3) is the
# position that the first entry of the matching *_idx list would occupy if
# that list were sorted ascending -- presumably so the feature rows line up
# with the coordinate rows selected through *_idx (TODO confirm).
featnet_prior = torch.load(f"{path}FeatNet_prior.pth")
leye_z_feat = _anchor_first(featnet_prior['leye_z_mean'], 8)
reye_z_feat = _anchor_first(featnet_prior['reye_z_mean'], 5)
mouth_z_feat = _anchor_first(featnet_prior['mouth_z_mean'], 3)
nose_z_feat = _anchor_first(featnet_prior['nose_z_mean'], 3)
# The jaw prior is taken as stored, with no reordering.
jaw_z_feat = featnet_prior['jaw_z_mean']


# Load the CoordNet checkpoint and keep its single 'coordnet' entry.
coordnet_state = torch.load(f"{path}CoordNet_2000epoch.pth")
coordnet = coordnet_state['coordnet']

# Select the rows of the shared landmark prior that belong to each region,
# in the order given by the corresponding index list.
coordnet_prior = torch.load(f"{path}CoordNet_prior.pth")
(
    leye_z_coord,
    reye_z_coord,
    mouth_z_coord,
    nose_z_coord,
    jaw_z_coord,
) = (
    coordnet_prior['landmark_z_mean'][region_idx]
    for region_idx in (leye_idx, reye_idx, mouth_idx, nose_idx, jaw_idx)
)


# Assemble everything into one dictionary and write it out as a single file.
# Insertion order: networks, then feature priors, then coordinate priors.
state = {}

# Entries taken from the FeatNet and CoordNet checkpoints.
state.update({
    'featnet_leye': featnet_leye,
    'featnet_reye': featnet_reye,
    'featnet_mouth': featnet_mouth,
    'featnet_nose': featnet_nose,
    'featnet_jaw': featnet_jaw,
    'coordnet': coordnet,
})

# Per-region priors from the FeatNet prior file.
state.update({
    'leye_z_ft': leye_z_feat,
    'reye_z_ft': reye_z_feat,
    'mouth_z_ft': mouth_z_feat,
    'nose_z_ft': nose_z_feat,
    'jaw_z_ft': jaw_z_feat,
})

# Per-region priors from the CoordNet prior file.
state.update({
    'leye_z_cd': leye_z_coord,
    'reye_z_cd': reye_z_coord,
    'mouth_z_cd': mouth_z_coord,
    'nose_z_cd': nose_z_coord,
    'jaw_z_cd': jaw_z_coord,
})

torch.save(state, f"{path}300W_state.pth")

